{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qFdPvlXBOdUN"
      },
      "source": [
        "# 生成随机数"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/guide/random_numbers\" class=\"\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" class=\"\">在 TensorFlow.org 上查看</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/random_numbers.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" class=\"\">在 Google Colab 中运行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/random_numbers.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" class=\"\">在 GitHub 上查看源代码</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/guide/random_numbers.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" class=\"\">下载笔记本</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BlGY1iiph_C2"
      },
      "source": [
        "TensorFlow 在 `tf.random` 模块中提供了一组伪随机数生成器 (RNG)。本文介绍如何控制随机数生成器，以及这些生成器如何与其他 Tensorflow 子系统交互。\n",
        "\n",
        "TensorFlow 提供了两种方法来控制随机数生成过程：\n",
        "\n",
        "1. 通过明确使用 `tf.random.Generator` 对象。每个此类对象都会在 `tf.Variable` 中维护一个状态，该状态在每次生成随机数后都会发生改变。\n",
        "\n",
        "2. 通过使用纯函数式无状态随机函数，如 `tf.random.stateless_uniform`。在同一设备上调用具有相同参数（包括种子）的这些函数会产生相同的结果。\n",
        "\n",
        "警告：目前尚未弃用 TF 1.x 中的旧版 RNG（如 `tf.random.uniform` 和 `tf.random.normal`），但强烈建议不要使用。\n",
        "\n",
        "警告：不保证随机数在不同 TensorFlow 版本间一致，请参阅：[版本兼容性](https://tensorflow.google.cn/guide/versions#what_is_not_covered)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zIGh9faCOp6x"
      },
      "source": [
        "## 设置"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ECDrttf0s8Nu"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Creates 2 virtual devices cpu:0 and cpu:1 for using distribution strategy\n",
        "physical_devices = tf.config.experimental.list_physical_devices(\"CPU\")\n",
        "tf.config.experimental.set_virtual_device_configuration(\n",
        "    physical_devices[0], [\n",
        "        tf.config.experimental.VirtualDeviceConfiguration(),\n",
        "        tf.config.experimental.VirtualDeviceConfiguration()\n",
        "    ])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eqMlrUsVu2Ai"
      },
      "source": [
        "## `tf.random.Generator` 类\n",
        "\n",
        "当您希望每次调用 RNG 都产生不同的结果时，可以使用 `tf.random.Generator` 类。它会维护一个内部状态（由 `tf.Variable` 对象管理），该状态在每次生成随机数时都会更新。由于该状态由 `tf.Variable` 管理，因此，它可以利用 `tf.Variable` 提供的所有功能，如简单的检查点、自动控制依赖项和线程安全性。\n",
        "\n",
        "通过手动创建 `tf.random.Generator`类的一个对象，您可以获得该生成器，或者通过调用 `tf.random.get_global_generator()`，您可以获得默认全局生成器："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7yU1E3JvxOQD"
      },
      "outputs": [],
      "source": [
        "g1 = tf.random.Generator.from_seed(1)\n",
        "print(g1.normal(shape=[2, 3]))\n",
        "g2 = tf.random.get_global_generator()\n",
        "print(g2.normal(shape=[2, 3]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QmRCeAvTxulW"
      },
      "source": [
        "有多种方法可以创建生成器对象。最简单的方法是使用 `Generator.from_seed`（代码如上），从种子创建生成器。种子可以是任何非负整数，`from_seed` 还有一个可选参数 `alg`，这是该生成器将使用的 RNG 算法。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kISbOE4Xfjhv"
      },
      "outputs": [],
      "source": [
        "g1 = tf.random.Generator.from_seed(1, alg='philox')\n",
        "print(g1.normal(shape=[2, 3]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_mCRaN7dfd8j"
      },
      "source": [
        "有关详细信息，请参阅后文中的*算法*部分。\n",
        "\n",
        "创建生成器的另一种方法是使用 `Generator.from_non_deterministic_state`。以这种方式创建的生成器首先会处于非确定状态，具体取决于时间和操作系统等因素。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gxPLCLsz00qY"
      },
      "outputs": [],
      "source": [
        "g = tf.random.Generator.from_non_deterministic_state()\n",
        "print(g.normal(shape=[2, 3]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zSAp2BMj1JZ6"
      },
      "source": [
        "还有其他方法可以创建生成器，比如说通过显式状态创建，本指南不作赘述。\n",
        "\n",
        "当使用 `tf.random.get_global_generator` 来获取全局生成器时，需要注意设备放置。第一次调用 `tf.random.get_global_generator` 时就会创建全局生成器（从非确定状态），并将其放置在该调用的作用域内的默认设备上。举个例子，如果第一次调用 `tf.random.get_global_generator` 的位置在 `tf.device(\"gpu\")` 作用域内，则会将全局生成器放置在 GPU 上，如果稍后要从 CPU 使用全局生成器，则会将其从 GPU 复制到 CPU。\n",
        "\n",
        "另外还有一个函数 `tf.random.set_global_generator`，可用于将全局生成器替换为另一个生成器对象。使用该函数前要三思，因为 `tf.function` 可能已获得旧全局生成器（作为弱引用），所以，替换它可能导致它被回收，从而中断 `tf.function`。有一种更好的方法可以重置全局生成器，即使用一个“重置”函数（如 `Generator.reset_from_seed`），这样就不会创建新的生成器对象。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "324S5bpd9HRg"
      },
      "outputs": [],
      "source": [
        "g = tf.random.Generator.from_seed(1)\n",
        "print(g.normal([]))\n",
        "print(g.normal([]))\n",
        "g.reset_from_seed(1)\n",
        "print(g.normal([]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z9H0wuvp9VwH"
      },
      "source": [
        "### 创建独立的随机数流\n",
        "\n",
        "许多应用都需要多个独立的随机数流，所谓独立，就是指不能相互重叠，也不能有统计学上可检测到的相关性。通过使用 `Generator.split` 创建多个一定相互独立的生成器即可实现此目的（即生成独立流）。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vg5_KN18OZjo"
      },
      "outputs": [],
      "source": [
        "g = tf.random.Generator.from_seed(1)\n",
        "print(g.normal([]))\n",
        "new_gs = g.split(3)\n",
        "for new_g in new_gs:\n",
        "  print(new_g.normal([]))\n",
        "print(g.normal([]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dqOaGVzKOsRJ"
      },
      "source": [
        "与 `normal` 之类的 RNG 方法类似，`split` 会改变调用它的生成器的状态（上例中为 `g`）。除相互之间保持独立外，新生成器 (`new_gs`) 还一定独立于旧生成器 (`g`)。\n",
        "\n",
        "当您想要确保使用的生成器位于与其他计算相同的设备上，从而避免跨设备复制的开销时，生成新生成器也很有用。例如： "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5jSnJBlUQzF3"
      },
      "outputs": [],
      "source": [
        "with tf.device(\"cpu\"):  # change \"cpu\" to the device you want\n",
        "  g = tf.random.get_global_generator().split(1)[0]  \n",
        "  print(g.normal([]))  # use of g won't cause cross-device copy, unlike the global generator"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sCxbccYMRdd4"
      },
      "source": [
        "注：在理论上，此处可以使用 `from_seed`（而不是 `split`）之类的构造函数获取新生成器，但这样做无法保证新生成器与全局生成器相互独立。同时也有使用同一种子或导致产生重叠随机数流的种子意外创建两个生成器的风险。\n",
        "\n",
        "您可以在拆分的生成器上调用 `split`，从而以递归方式执行拆分。递归深度没有限制（除非发生整数溢出）。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8JUgnQM_O0lg"
      },
      "source": [
        "### 与 `tf.function` 交互\n",
        "\n",
        "与 `tf.function` 一起使用时，`tf.random.Generator` 遵循与 `tf.Variable` 相同的原则。这包括三个方面："
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jnSjhY6WM-J8"
      },
      "source": [
        "#### 在 `tf.function` 的外部创建生成器\n",
        "\n",
        "`tf.function` 可以使用在其外部创建的生成器。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a5EEy0E2UHMw"
      },
      "outputs": [],
      "source": [
        "g = tf.random.Generator.from_seed(1)\n",
        "@tf.function\n",
        "def foo():\n",
        "  return g.normal([])\n",
        "print(foo())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L_8kC7kbO5uu"
      },
      "source": [
        "调用该函数时，用户需要确保生成器对象仍处于活动状态（没有被回收）。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PwIrBv_zUYwI"
      },
      "source": [
        "#### 在 `tf.function` 的内部创建生成器\n",
        "\n",
        "只有 `tf.function` 第一次运行时，才可以在其内部创建生成器。 "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3JzpUvqJU4MW"
      },
      "outputs": [],
      "source": [
        "g = None\n",
        "@tf.function\n",
        "def foo():\n",
        "  global g\n",
        "  if g is None:\n",
        "    g = tf.random.Generator.from_seed(1)\n",
        "  return g.normal([])\n",
        "print(foo())\n",
        "print(foo())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UaTVnOhHVM9a"
      },
      "source": [
        "#### 将生成器作为参数传递给 `tf.function`\n",
        "\n",
        "当用作 `tf.function` 的参数时，具有相同状态大小（状态大小由 RNG 算法确定）的不同生成器对象不会导致回溯 `tf.function`，而具有不同状态大小的不同生成器对象则会导致回溯。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DeR9kvt0V-ad"
      },
      "outputs": [],
      "source": [
        "num_traces = 0\n",
        "@tf.function\n",
        "def foo(g):\n",
        "  global num_traces\n",
        "  num_traces += 1\n",
        "  return g.normal([])\n",
        "foo(tf.random.Generator.from_seed(1))\n",
        "foo(tf.random.Generator.from_seed(2))\n",
        "print(num_traces)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fxcS6IY8WZuh"
      },
      "source": [
        "### 与分布策略交互\n",
        "\n",
        "`Generator` 与分布策略有三种交互方式。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GyZv9QJkZfkQ"
      },
      "source": [
        "#### 在分布策略的外部创建生成器\n",
        "\n",
        "如果是在策略作用域的外部创建的生成器，则会序列化访问此生成器的所有副本，因此，每一个副本都会得到不同的随机数。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HX_beT9SZWMp"
      },
      "outputs": [],
      "source": [
        "g = tf.random.Generator.from_seed(1)\n",
        "strat = tf.distribute.MirroredStrategy(devices=[\"cpu:0\", \"cpu:1\"])\n",
        "with strat.scope():\n",
        "  def f():\n",
        "    print(g.normal([]))\n",
        "  results = strat.run(f)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ydYQbUqLPAgH"
      },
      "source": [
        "请注意，这种使用方法可能产生性能问题，因为生成器的设备与副本不同。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yal4LbBKbAeN"
      },
      "source": [
        "#### 在分布策略的内部创建生成器\n",
        "\n",
        "不允许在策略作用域内部创建生成器，因为这会导致在如何复制生成器方面出现歧义。比方说，是应该复制生成器，从而让每一个副本都获得相同的随机数，还是应该“拆分”，从而让每一个副本获得不同的随机数。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T6McVq-gbK_d"
      },
      "outputs": [],
      "source": [
        "strat = tf.distribute.MirroredStrategy(devices=[\"cpu:0\", \"cpu:1\"])\n",
        "with strat.scope():\n",
        "  try:\n",
        "    tf.random.Generator.from_seed(1)\n",
        "  except ValueError as e:\n",
        "    print(\"ValueError:\", e)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pqQfWkMWQnnI"
      },
      "source": [
        "请注意，`Strategy.run` 会在策略作用域内隐式运行参数函数："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X6Ceqha3RKKo"
      },
      "outputs": [],
      "source": [
        "strat = tf.distribute.MirroredStrategy(devices=[\"cpu:0\", \"cpu:1\"])\n",
        "def f():\n",
        "  tf.random.Generator.from_seed(1)\n",
        "try:\n",
        "  strat.run(f)\n",
        "except ValueError as e:\n",
        "  print(\"ValueError:\", e)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ODLS8njzbUEF"
      },
      "source": [
        "#### 将生成器作为参数传递给 `Strategy.run`\n",
        "\n",
        "如果您希望每个副本都使用自己的生成器，则需要通过复制或拆分创建 `n`（`n` 表示副本数量）个生成器，然后将其作为参数传递给 `Strategy.run`。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YurAsX3nbROP"
      },
      "outputs": [],
      "source": [
        "strat = tf.distribute.MirroredStrategy(devices=[\"cpu:0\", \"cpu:1\"])\n",
        "gs = tf.random.get_global_generator().split(2)\n",
        "# to_args is a workaround for the absence of APIs to create arguments for \n",
        "# run. It will be replaced when such APIs are available.\n",
        "def to_args(gs):  \n",
        "  with strat.scope():\n",
        "    def f():\n",
        "      return [gs[tf.distribute.get_replica_context().replica_id_in_sync_group]]\n",
        "    return strat.run(f)\n",
        "args = to_args(gs)\n",
        "def f(g):\n",
        "  print(g.normal([]))\n",
        "results = strat.run(f, args=args)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "73an1POpsi6V"
      },
      "source": [
        "## 无状态 RNG\n",
        "\n",
        "无状态 RNG 的使用方法非常简单。因为它们是纯函数，不涉及状态或副作用。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0-aOOA3gasn_"
      },
      "outputs": [],
      "source": [
        "print(tf.random.stateless_normal(shape=[2, 3], seed=[1, 2]))\n",
        "print(tf.random.stateless_normal(shape=[2, 3], seed=[1, 2]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2O_D-RAFNH2Q"
      },
      "source": [
        "每个无状态 RNG 都需要一个 `seed` 参数，该参数必须是形状为 `[2]` 的整数张量。该运算的结果完全由种子确定。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4BvGkPnaOUPF"
      },
      "source": [
        "## 算法"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "58-8kvR4pRwO"
      },
      "source": [
        "### 基本信息\n",
        "\n",
        "`tf.random.Generator` 类和 `stateless` 函数在所有设备上都支持 Philox 算法（写作 `\"philox\"` 或 `tf.random.Algorithm.PHILOX`）。\n",
        "\n",
        "如果使用相同的算法且从相同的状态开始，则不同的设备会生成相同的整数。它们还可以生成“几乎相同”的浮点数，虽然由于设备执行浮点计算的方式不同（如降阶），数值可能存在微小的差异。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WETA04F1OYPL"
      },
      "source": [
        "### XLA 设备\n",
        "\n",
        "在 XLA 驱动的设备（如 TPU 以及启用 XLA 时的 CPU/GPU）上，还支持 ThreeFry 算法（写作 `\"threefry\"` 或 `tf.random.Algorithm.THREEFRY`）。与 Philox 算法相比，该算法在 TPU 上执行速度较快，而在 CPU/GPU 上执行速度较慢。 "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c04JkebCPTPu"
      },
      "source": [
        "有关这些算法的更多详细信息，请参阅论文[“Parallel Random Numbers: As Easy as 1, 2, 3”](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf)。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "random_numbers.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
